#ifdef BUILD_NNET

#include "operators/membound.h"
#include "nnet/Visitor/CheckOOBVisitor.h"
#include "nnet/Visitor/HashVisitor.h"
#include "nnet/Visitor/MergeMemboundMutator.h"

namespace infini {

/// Construct a memory-bound operator wrapping an NNet expression.
/// @param graph       owning graph, used only for shape/validity checking
/// @param input       runtime input tensors (must match nnetInputs, see inferShape)
/// @param output      runtime output tensors
/// @param nnetInputs  NNet-side tensor descriptors corresponding to `input`
/// @param expr        the NNet expression describing the computation
/// @param exec_time   estimated execution time used for scheduling/reporting
/// @param hint        free-form annotation carried along with the op
MemBoundObj::MemBoundObj(GraphObj *graph, const TensorVec &input,
                         const TensorVec &output,
                         const std::vector<nnet::Tensor> &nnetInputs,
                         nnet::Expr expr, double exec_time, std::string hint)
    : OperatorObj(OpType::MemBound, input, output), nnetInputs(nnetInputs),
      expr(expr), exec_time(exec_time), hint(hint) {
    IT_ASSERT(checkValid(graph));
    // The expression must not read outside the bounds of its input tensors.
    IT_ASSERT(!checkOOB(expr));
    hash = calcHash(expr);

    // fuse stages in nnet expr to reduce kernels generated by TVM
    if (auto mergedExpr =
            nnet::MergeMemboundMutator({expr}).merge(false, true)) {
        simplifiedExpr = mergedExpr;
        // Merging must not introduce out-of-bounds accesses.
        IT_ASSERT(!checkOOB(simplifiedExpr));
        simplifiedHash = calcHash(simplifiedExpr);
    } else {
        // Merge failed: fall back to the original expression so that
        // simplifiedExpr/simplifiedHash are always valid.
        simplifiedExpr = expr;
        simplifiedHash = hash;
    }
}

/// Render a human-readable summary of this operator: guid, input/output
/// tensor guids, exec time, NNet input tensors, expression hashes, and the
/// readable forms of both the original and the simplified expression.
string MemBoundObj::toString() const {
    std::ostringstream out;
    out << "MemBound[" << getGuid() << "](";
    // Input tensor guids, separated by single spaces.
    const char *delim = "";
    for (size_t idx = 0; idx < inputs.size(); ++idx) {
        out << delim << "i" << idx << "=" << inputs[idx]->getGuid();
        delim = " ";
    }
    out << ", ";
    // Output tensor guids, same space-separated layout.
    delim = "";
    for (size_t idx = 0; idx < outputs.size(); ++idx) {
        out << delim << "o" << idx << "=" << outputs[idx]->getGuid();
        delim = " ";
    }
    out << ", ";
    out << "exec_time=" << exec_time << ", ";
    out << "NNet Inputs=[";
    for (const auto &nnetTensor : nnetInputs)
        out << nnetTensor->toReadable() << ",";
    out << "]";
    out << ", ExprHash=" << hash;
    out << ", SimplifiedExprHash=" << simplifiedHash;
    out << ")\n";
    out << ">>> Original expr\n"
        << (expr ? expr->toReadable() : "Empty expression") << "\n";
    out << ">>> Simplified expr\n"
        << (simplifiedExpr ? simplifiedExpr->toReadable() : "Empty expression")
        << "\n";
    return out.str();
}

/// Infer output shapes from the given runtime inputs.
/// Returns the output shape of the root RangeOp of `expr`, or an empty
/// optional if the inputs do not match the NNet-side tensor descriptors.
optional<vector<Shape>> MemBoundObj::inferShape(const TensorVec &inputs) {
    // inputs have to match nnetInputs exactly
    if (inputs.size() != nnetInputs.size())
        return {};
    for (size_t i = 0; i < inputs.size(); ++i)
        if (inputs[i]->getDims() != nnetInputs[i]->getShape())
            return {};
    // NOTE(review): assumes the root of `expr` is a RangeOpNode; if the cast
    // fails, `as` presumably yields null and this dereference is UB — the
    // constructor's checkOOB already casts the same way, so the invariant
    // appears to hold, but confirm against nnet::as semantics.
    return {{nnet::as<nnet::RangeOpNode>(expr)->getOutputShape()}};
}

/// Identify this workload for kernel caching: operator type plus the hash of
/// the simplified expression.
vector<int> MemBoundObj::getWorkloadVector() const {
    // Named cast instead of C-style cast; note this truncates HashType to
    // int (high bits discarded) — kept for interface compatibility.
    return {type.underlying(), static_cast<int>(simplifiedHash)};
}

vector<int> MemBoundObj::getOpAttrVector() const { return getWorkloadVector(); }

/// Compute the structural hash of an NNet expression via HashVisitor.
HashType MemBoundObj::calcHash(nnet::Expr expr) {
    nnet::HashVisitor hasher;
    return hasher.dispatch(expr);
}

/// Check whether the expression's root RangeOp accesses its inputs out of
/// bounds. Returns true if an out-of-bounds access is detected.
bool MemBoundObj::checkOOB(nnet::Expr expr) {
    auto rangeOp = nnet::as<nnet::RangeOpNode>(expr);
    return nnet::CheckOOBVisitor().checkRangeOp(rangeOp);
}

} // namespace infini

#endif
